def get_TPTNFNFP(logits, labels):
    """Count binary confusion-matrix entries from class scores.

    The predicted class of each sample is ``logits.argmax(dim=1)``. Class 1
    is treated as positive and class 0 as negative. Samples whose predicted
    or true class is neither 0 nor 1 are not counted in any bucket.

    Args:
        logits: Score tensor whose dimension 1 indexes the classes.
        labels: Ground-truth class tensor, compared elementwise with the
            argmax predictions.

    Returns:
        Tuple ``(TP, TN, FN, FP)`` of count tensors moved to the CPU.
        Note the order: false negatives come before false positives.
    """
    predicted = logits.argmax(dim=1)

    # Build each side's positive/negative masks once, then combine them.
    pred_pos, pred_neg = predicted == 1, predicted == 0
    true_pos, true_neg = labels == 1, labels == 0

    def _count(mask):
        # Move to CPU before reducing so the counts live on the host.
        return mask.cpu().sum()

    tp = _count(pred_pos & true_pos)   # predicted 1, actually 1
    tn = _count(pred_neg & true_neg)   # predicted 0, actually 0
    fn = _count(pred_neg & true_pos)   # predicted 0, actually 1
    fp = _count(pred_pos & true_neg)   # predicted 1, actually 0
    return tp, tn, fn, fp
